
 test error


On the Mechanisms of Weak-to-Strong Generalization: A Theoretical Perspective

Neural Information Processing Systems

Weak-to-strong generalization, where a student model trained on imperfect labels generated by a weaker teacher nonetheless surpasses that teacher, has been widely observed, but the mechanisms that enable it have remained poorly understood. In this paper, through a theoretical analysis of simple models, we uncover three core mechanisms that can drive this phenomenon. First, by analyzing ridge linear regression, we study the interplay between the teacher and student regularization parameters and prove that a student can compensate for a teacher's under-regularization and achieve lower test error. We also analyze the role of the models' parameterization regime and show that qualitatively different phenomena can occur in different regimes. Second, by analyzing weighted ridge linear regression, we show that a student model whose regularization structure is better aligned with the target function can outperform its teacher. Third, in a nonlinear multi-index learning setting, we demonstrate that a student can learn easy, task-specific features from the teacher while leveraging its own broader pre-training to learn hard features that the teacher cannot capture.
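The two ridge mechanisms can be made concrete with the closed-form estimator. The sketch below is a minimal illustration under assumptions chosen for simplicity (the student sees noiseless pseudo-labels on a fresh unlabeled sample, and inputs are isotropic Gaussian); the helper names `ridge_fit`, `weak_to_strong`, and `excess_risk` are hypothetical and this is not the paper's code. With penalty $\lambda w^\top \Gamma w$ the minimizer is $w = (X^\top X + \lambda\Gamma)^{-1} X^\top y$, where $\Gamma = I$ gives plain ridge and a non-identity diagonal $\Gamma$ gives weighted ridge.

```python
import numpy as np


def ridge_fit(X, y, lam, gamma=None):
    """Minimize ||y - X w||^2 + lam * w^T Gamma w in closed form.

    gamma=None uses Gamma = I (plain ridge); otherwise gamma holds the
    diagonal of Gamma (weighted ridge).
    """
    d = X.shape[1]
    G = np.eye(d) if gamma is None else np.diag(gamma)
    return np.linalg.solve(X.T @ X + lam * G, X.T @ y)


def weak_to_strong(X_lab, y_lab, X_unlab, lam_t, lam_s, gamma_s=None):
    """Hypothetical pipeline: teacher fits labels, student fits teacher's pseudo-labels."""
    w_t = ridge_fit(X_lab, y_lab, lam_t)              # weak teacher
    pseudo = X_unlab @ w_t                            # imperfect labels from the teacher
    w_s = ridge_fit(X_unlab, pseudo, lam_s, gamma_s)  # student with its own regularizer
    return w_t, w_s


def excess_risk(w, w_star):
    # For isotropic Gaussian inputs, E[(x^T w - x^T w_star)^2] = ||w - w_star||^2.
    return float(np.sum((w - w_star) ** 2))
```

Because the student's targets are exactly $X_u w_t$, its solution is $w_s = (X_u^\top X_u + \lambda_s\Gamma)^{-1} X_u^\top X_u\, w_t$, i.e. an extra shrinkage applied to the teacher's weights. This is one way to see how a larger student penalty can offset an under-regularized teacher; whether `excess_risk` actually decreases depends on the regime, which is what the analysis characterizes.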


The Nuclear Route: Sharp Asymptotics of ERM in Overparameterized Quadratic Networks

Neural Information Processing Systems

We study the high-dimensional asymptotics of empirical risk minimization (ERM) in over-parametrized two-layer neural networks with quadratic activations trained on synthetic data. We derive sharp asymptotics for both training and test errors by mapping the ℓ2-regularized learning problem to a convex matrix sensing task with nuclear norm penalization. This reveals that capacity control in such networks emerges from a low-rank structure in the learned feature maps. Our results characterize the global minima of the loss and yield precise generalization thresholds, showing how the width of the target function governs learnability. This analysis bridges and extends ideas from spin-glass methods, matrix factorization, and convex optimization and emphasizes the deep link between low-rank matrix sensing and learning in quadratic neural networks.
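The link between ℓ2 regularization and the nuclear norm can be checked numerically in the simplest case. This is a minimal sketch assuming every second-layer weight is fixed to +1 (an illustrative assumption; the paper's exact parameterization may differ), with hypothetical helper names. Then $f(x) = \sum_j (w_j^\top x)^2 = x^\top S x$ with $S = W^\top W$, and because $S$ is positive semidefinite, $\|S\|_* = \operatorname{tr}(S) = \|W\|_F^2$: the weight-decay penalty on $W$ equals a nuclear-norm penalty on $S$.

```python
import numpy as np


def quad_net(W, x):
    # Two-layer net with quadratic activation and unit second-layer weights (assumption).
    return float(np.sum((W @ x) ** 2))


def feature_matrix(W):
    # S = W^T W, so that quad_net(W, x) == x^T S x.
    return W.T @ W


def nuclear_norm(S):
    # Nuclear norm = sum of singular values.
    return float(np.linalg.svd(S, compute_uv=False).sum())
```

A network with m hidden units always gives rank(S) ≤ m, and penalizing the nuclear norm of S favors solutions with few effective directions, which is the low-rank capacity control mentioned above. With mixed-sign second-layer weights, $S = W^\top \operatorname{diag}(a) W$ is no longer positive semidefinite and this one-line identity does not apply directly.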


Dynamical Decoupling of Generalization and Overfitting in Large Two-Layer Networks

Neural Information Processing Systems

Understanding the inductive bias and generalization properties of large overparametrized machine learning models requires characterizing the dynamics of the training algorithm. We study the learning dynamics of large two-layer neural networks via dynamical mean field theory, a well-established technique of nonequilibrium statistical physics. We show that, for large network width m and a large number of samples per input dimension n/d, the training dynamics exhibits a separation of timescales, which implies: (i) the emergence of a slow timescale associated with the growth in Gaussian/Rademacher complexity of the network; (ii) an inductive bias towards small complexity if the initialization has small enough complexity; (iii) a dynamical decoupling between feature learning and overfitting regimes; (iv) a non-monotone behavior of the test error, associated with a 'feature unlearning' regime at large times.


Understanding the Generalization of Stochastic Gradient Adam in Learning Neural Networks

Neural Information Processing Systems

Adam is a popular and widely used adaptive gradient method in deep learning, and it has also received considerable attention in theoretical research. However, most existing theoretical work analyzes its full-batch version, which differs fundamentally from the stochastic variant used in practice. Unlike SGD, stochastic Adam does not converge to its full-batch counterpart even with infinitesimal learning rates. We present the first theoretical characterization of how batch size affects Adam's generalization, analyzing two-layer over-parameterized CNNs on image data. Our results reveal that while full-batch Adam and AdamW with proper weight decay λ both converge to solutions with poor test error, their mini-batch variants can achieve near-zero test error. We further prove that Adam has a strictly smaller effective weight decay bound than AdamW, which theoretically explains why Adam requires more sensitive λ tuning.
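The difference between weight decay in Adam and in AdamW can be seen in a single update. The sketch below follows the standard formulations (Adam with an L2 term added to the gradient before adaptive normalization; AdamW with decay applied directly to the weights, as proposed by Loshchilov and Hutter). The function name and the toy numbers are illustrative assumptions, and the sketch does not reproduce the paper's effective-weight-decay bound or its mini-batch analysis.

```python
import numpy as np


def adam_step(w, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8,
              wd=0.0, decoupled=False):
    """One Adam (decoupled=False) or AdamW (decoupled=True) step.

    decoupled=False: the L2 term wd * w is added to the gradient, so it is
    rescaled by the adaptive denominator like the data gradient.
    decoupled=True: decay wd * w is applied to the weights outside the
    adaptive rescaling.
    """
    if not decoupled:
        g = g + wd * w
    m = beta1 * m + (1 - beta1) * g            # first-moment estimate
    v = beta2 * v + (1 - beta2) * g ** 2       # second-moment estimate
    m_hat = m / (1 - beta1 ** t)               # bias corrections
    v_hat = v / (1 - beta2 ** t)
    update = m_hat / (np.sqrt(v_hat) + eps)
    if decoupled:
        update = update + wd * w
    return w - lr * update, m, v
```

With a zero data gradient and w = 1, the first coupled step shrinks w by about lr whatever the value of λ, because the decay term is divided by its own magnitude in the adaptive denominator, whereas the decoupled step shrinks it by lr·λ. This normalization is one intuition for why λ behaves differently in the two optimizers.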


Decoupled Descent: Exact Test Error Tracking Via Approximate Message Passing

arXiv.org Machine Learning

In modern parametric model training, full-batch gradient descent (and its variants) suffers from a progressively stronger bias towards the exact realization of the training data; this drives the systematic "generalization gap", where the train error becomes an unreliable proxy for the test error. Existing approaches either argue through complex analysis that this gap is benign or sacrifice data to a validation set. In contrast, we introduce decoupled descent (DD), a novel theory-based training algorithm that satisfies a train-test identity, enforcing the train error to asymptotically track the test error for stylized Gaussian mixture models. Within this regime, leveraging approximate message passing theory, DD iteratively cancels the biases due to data reuse, rigorously demonstrating the feasibility of zero-cost validation and 100% data utilization. Moreover, DD is governed by a low-dimensional state evolution recursion, which makes the dynamics of the algorithm transparent and tractable. We validate DD on XOR classification, where it outperforms GD; we additionally evaluate it on noisy MNIST and on nonlinear probing of CIFAR-10, demonstrating that even when our stylized assumptions are relaxed, DD narrows the generalization gap compared to GD.
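Decoupled descent itself relies on approximate message passing and is not reproduced here. As a swapped-in and much simpler illustration of the same two points (train error is an optimistic proxy for test error, and the bias from reusing training data can sometimes be removed analytically without a held-out set), the sketch below uses the classical closed-form leave-one-out identity for ridge regression, $e_i^{\mathrm{LOO}} = e_i / (1 - H_{ii})$ with $H = X(X^\top X + \lambda I)^{-1}X^\top$. The function names are hypothetical.

```python
import numpy as np


def ridge_hat(X, lam):
    # Hat matrix H mapping targets y to fitted values H @ y.
    d = X.shape[1]
    return X @ np.linalg.solve(X.T @ X + lam * np.eye(d), X.T)


def loo_residuals(X, y, lam):
    # Closed form: each training residual inflated by 1 / (1 - H_ii).
    H = ridge_hat(X, lam)
    resid = y - H @ y
    return resid / (1.0 - np.diag(H))


def loo_residuals_brute_force(X, y, lam):
    # Refit without each point in turn; used only to check the identity.
    n, d = X.shape
    out = np.empty(n)
    for i in range(n):
        keep = np.arange(n) != i
        Xi, yi = X[keep], y[keep]
        w = np.linalg.solve(Xi.T @ Xi + lam * np.eye(d), Xi.T @ yi)
        out[i] = y[i] - X[i] @ w
    return out
```

Since 0 ≤ H_ii < 1 for λ > 0, each leave-one-out residual is at least as large in magnitude as the corresponding training residual, which quantifies the optimism of the training error without spending any data on validation.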




DAMEX: Dataset-aware Mixture-of-Experts for visual understanding of mixture-of-datasets (Supplementary Material)

Neural Information Processing Systems

Here we provide theoretical evidence that vanilla MoE does not guarantee convergence when mixing multiple datasets. Consider a binary classification problem over P-patch inputs, where each patch has d dimensions and the label is $y \in \{\pm 1\}$. Thus, a labeled data point $(x, y)$ has input $x = (x^{(1)}, x^{(2)}, x^{(3)}, \dots, x^{(P)}) \in (\mathbb{R}^d)^P$, a collection of P patches, with y as the data label. The data x is generated from K clusters. Chen et al. [2022] prove that in such a binary classification problem, an MoE layer converges to an o(1) test loss and zero training loss.
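To make the input format concrete, the sketch below builds P-patch inputs $x \in (\mathbb{R}^d)^P$ with labels in $\{\pm 1\}$ drawn from K clusters and passes them through a patch-level MoE layer with top-1 routing. The generator (one patch carrying $y\,v_k$, one carrying the cluster center $c_k$, Gaussian noise elsewhere), the cubic expert activation, and all names are hypothetical choices for illustration, not the exact construction of Chen et al. [2022].

```python
import numpy as np


def make_patch_data(n, P, d, K, rng, noise=0.1):
    """Hypothetical K-cluster data: inputs of shape (n, P, d), labels in {-1, +1}."""
    centers = rng.standard_normal((K, d))   # cluster centers c_k
    signals = rng.standard_normal((K, d))   # per-cluster label directions v_k
    X = noise * rng.standard_normal((n, P, d))
    y = rng.choice([-1, 1], size=n)
    k = rng.integers(0, K, size=n)          # cluster index of each example
    for i in range(n):
        p_sig, p_ctr = rng.choice(P, size=2, replace=False)
        X[i, p_sig] += y[i] * signals[k[i]]
        X[i, p_ctr] += centers[k[i]]
    return X, y, k


def moe_forward(X, router, experts):
    """Top-1 MoE over patches.

    X: (n, P, d); router: (M, d); experts: (M, J, d).
    Gate h_m(x) = sum_p <theta_m, x^(p)>; expert m outputs
    sum_p sum_j <w_{m,j}, x^(p)>^3 (cubic activation is an assumption).
    """
    gate = np.einsum("npd,md->nm", X, router)
    top = np.argmax(gate, axis=1)                  # chosen expert per example
    pre = np.einsum("npd,mjd->nmpj", X, experts)   # all expert pre-activations
    out = np.sum(pre ** 3, axis=(2, 3))            # (n, M) expert outputs
    return out[np.arange(X.shape[0]), top], top
```

Training the experts and the router by gradient descent is where the convergence question above arises; the sketch only fixes the shapes and the routing rule.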